{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <div style='color:#fe618e;font-weight:800;'>使用 Optuna 库选择超参数</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nes_py.wrappers import JoypadSpace\n",
    "import gym_super_mario_bros\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT\n",
    "import time\n",
    "from matplotlib import pyplot as plt\n",
    "from gym.wrappers import GrayScaleObservation\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "from stable_baselines3.common.vec_env import VecFrameStack\n",
    "import os\n",
    "from stable_baselines3 import PPO\n",
    "\n",
    "from stable_baselines3.common.results_plotter import load_results, ts2xy\n",
    "import numpy as np\n",
    "from stable_baselines3.common.callbacks import BaseCallback\n",
    "\n",
    "import optuna\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline environment: Mario limited to SIMPLE_MOVEMENT, wrapped in Monitor\n",
    "# (given ./monitor_log/), converted to grayscale and stacked 4 frames deep.\n",
    "monitor_dir = r'./monitor_log/'\n",
    "os.makedirs(monitor_dir, exist_ok=True)\n",
    "\n",
    "env = JoypadSpace(gym_super_mario_bros.make('SuperMarioBros-v0'), SIMPLE_MOVEMENT)\n",
    "env = GrayScaleObservation(Monitor(env, monitor_dir), keep_dim=True)\n",
    "# NOTE: the lambda reads `env` lazily; this relies on DummyVecEnv invoking it\n",
    "# before `env` is rebound to the vectorised environment.\n",
    "env = VecFrameStack(DummyVecEnv([lambda: env]), 4, channels_order='last')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize_ppo(trial):\n",
    "    \"\"\"Sample one set of PPO hyperparameters from an Optuna trial.\n",
    "\n",
    "    Args:\n",
    "        trial: Object exposing Optuna's `suggest_int` / `suggest_float` API.\n",
    "\n",
    "    Returns:\n",
    "        dict: Keyword arguments for `PPO(...)` with the keys `n_steps`,\n",
    "        `gamma`, `learning_rate`, `clip_range` and `gae_lambda`.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        # Multiples of 64 only: PPO is created with its default mini-batch size\n",
    "        # of 64, and any other rollout length ends each pass over the buffer\n",
    "        # with a truncated mini-batch (SB3 warns about this).\n",
    "        'n_steps': trial.suggest_int('n_steps', 2048, 8192, step=64),\n",
    "        # suggest_float replaces the deprecated suggest_loguniform / suggest_uniform.\n",
    "        'gamma': trial.suggest_float('gamma', 0.8, 0.9999, log=True),\n",
    "        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True),\n",
    "        'clip_range': trial.suggest_float('clip_range', 0.1, 0.4),\n",
    "        'gae_lambda': trial.suggest_float('gae_lambda', 0.8, 0.99),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from stable_baselines3.common.evaluation import evaluate_policy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize_function(trial, total_timesteps=1000, n_eval_episodes=1):\n",
    "    \"\"\"Optuna objective: train PPO on Super Mario Bros and return its mean reward.\n",
    "\n",
    "    Args:\n",
    "        trial: Optuna trial; hyperparameters are sampled from it by `optimize_ppo`.\n",
    "        total_timesteps: Training budget passed to `model.learn`. The default\n",
    "            (1000) is the notebook's quick-run value; use e.g. 200000 for a\n",
    "            longer run.\n",
    "        n_eval_episodes: Episodes passed to `evaluate_policy`; raise it (e.g. 5)\n",
    "            for a less noisy score.\n",
    "\n",
    "    Returns:\n",
    "        The mean evaluation reward, or -1000 if any step of the trial raised.\n",
    "        A genuine reward can be lower than -1000, so check the trial's `error`\n",
    "        user attribute to tell the two cases apart.\n",
    "    \"\"\"\n",
    "    base_env = None\n",
    "    vec_env = None\n",
    "    try:\n",
    "        base_env = gym_super_mario_bros.make('SuperMarioBros-v0')\n",
    "        base_env = JoypadSpace(base_env, SIMPLE_MOVEMENT)\n",
    "\n",
    "        monitor_dir = r'./monitor_log/'\n",
    "        os.makedirs(monitor_dir, exist_ok=True)\n",
    "        base_env = Monitor(base_env, monitor_dir)\n",
    "\n",
    "        base_env = GrayScaleObservation(base_env, keep_dim=True)\n",
    "        # Separate names so the lambda never captures a variable that is rebound later.\n",
    "        vec_env = DummyVecEnv([lambda: base_env])\n",
    "        vec_env = VecFrameStack(vec_env, 4, channels_order='last')\n",
    "\n",
    "        model_params = optimize_ppo(trial)\n",
    "\n",
    "        tensorboard_log = r'./tensorboard_log/'\n",
    "        model = PPO(\"CnnPolicy\", vec_env, verbose=0, tensorboard_log=tensorboard_log, **model_params)\n",
    "        model.learn(total_timesteps=total_timesteps)\n",
    "\n",
    "        mean_reward, _ = evaluate_policy(model, vec_env, n_eval_episodes=n_eval_episodes)\n",
    "\n",
    "        # NOTE(review): machine-specific absolute path, kept for compatibility\n",
    "        # with whatever loads these checkpoints.\n",
    "        save_model_dir = r'F:\\\\RL_Mario1\\\\'\n",
    "        os.makedirs(save_model_dir, exist_ok=True)\n",
    "        save_path = os.path.join(save_model_dir, 'trial_{}_best_model'.format(trial.number))\n",
    "        model.save(save_path)\n",
    "\n",
    "        return mean_reward\n",
    "\n",
    "    except Exception as e:\n",
    "        # Keep the study running, but report the failure and record it on the\n",
    "        # trial instead of hiding it behind the sentinel score.\n",
    "        print('Trial {} failed: {!r}'.format(trial.number, e))\n",
    "        trial.set_user_attr('error', repr(e))\n",
    "        return -1000\n",
    "\n",
    "    finally:\n",
    "        # Release the environment on both the success and the failure path.\n",
    "        closable = vec_env if vec_env is not None else base_env\n",
    "        if closable is not None:\n",
    "            closable.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import optuna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2022-03-10 10:24:43,542]\u001b[0m A new study created in memory with name: no-name-bc01070f-8a31-4b43-90a9-69370b42bb39\u001b[0m\n",
      "D:\\software\\e_anaconda\\envs\\pytorch\\lib\\site-packages\\stable_baselines3\\ppo\\ppo.py:137: UserWarning: You have specified a mini-batch size of 64, but because the `RolloutBuffer` is of size `n_steps * n_envs = 3736`, after every 58 untruncated mini-batches, there will be a truncated mini-batch of size 24\n",
      "We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n",
      "Info: (n_steps=3736 and n_envs=1)\n",
      "  warnings.warn(\n",
      "D:\\software\\e_anaconda\\envs\\pytorch\\lib\\site-packages\\gym_super_mario_bros\\smb_env.py:148: RuntimeWarning: overflow encountered in ubyte_scalars\n",
      "  return (self.ram[0x86] - self.ram[0x071c]) % 256\n",
      "\u001b[32m[I 2022-03-10 10:28:37,340]\u001b[0m Trial 0 finished with value: -1260.0 and parameters: {'n_steps': 3736, 'gamma': 0.8113866241862995, 'learning_rate': 7.613128199159764e-05, 'clip_range': 0.29279385654405704, 'gae_lambda': 0.937024171211452}. Best is trial 0 with value: -1260.0.\u001b[0m\n",
      "D:\\software\\e_anaconda\\envs\\pytorch\\lib\\site-packages\\stable_baselines3\\ppo\\ppo.py:137: UserWarning: You have specified a mini-batch size of 64, but because the `RolloutBuffer` is of size `n_steps * n_envs = 2592`, after every 40 untruncated mini-batches, there will be a truncated mini-batch of size 32\n",
      "We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n",
      "Info: (n_steps=2592 and n_envs=1)\n",
      "  warnings.warn(\n",
      "D:\\software\\e_anaconda\\envs\\pytorch\\lib\\site-packages\\gym_super_mario_bros\\smb_env.py:148: RuntimeWarning: overflow encountered in ubyte_scalars\n",
      "  return (self.ram[0x86] - self.ram[0x071c]) % 256\n",
      "\u001b[32m[I 2022-03-10 10:29:31,488]\u001b[0m Trial 1 finished with value: 680.0 and parameters: {'n_steps': 2592, 'gamma': 0.8450658799536929, 'learning_rate': 2.4952570083224206e-05, 'clip_range': 0.3315680151889212, 'gae_lambda': 0.9606306727084878}. Best is trial 1 with value: 680.0.\u001b[0m\n",
      "D:\\software\\e_anaconda\\envs\\pytorch\\lib\\site-packages\\stable_baselines3\\ppo\\ppo.py:137: UserWarning: You have specified a mini-batch size of 64, but because the `RolloutBuffer` is of size `n_steps * n_envs = 7743`, after every 120 untruncated mini-batches, there will be a truncated mini-batch of size 63\n",
      "We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n",
      "Info: (n_steps=7743 and n_envs=1)\n",
      "  warnings.warn(\n",
      "D:\\software\\e_anaconda\\envs\\pytorch\\lib\\site-packages\\gym_super_mario_bros\\smb_env.py:148: RuntimeWarning: overflow encountered in ubyte_scalars\n",
      "  return (self.ram[0x86] - self.ram[0x071c]) % 256\n",
      "\u001b[32m[I 2022-03-10 10:32:03,598]\u001b[0m Trial 2 finished with value: 680.0 and parameters: {'n_steps': 7743, 'gamma': 0.8874971551801487, 'learning_rate': 1.3457059614838189e-05, 'clip_range': 0.1796126116316833, 'gae_lambda': 0.967313194894678}. Best is trial 1 with value: 680.0.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Maximise the value returned by optimize_function (the mean evaluation reward).\n",
    "study = optuna.create_study(direction='maximize')\n",
    "# Only 3 trials are run here, presumably as a quick check; a fuller search would be:\n",
    "# study.optimize(optimize_function, n_trials=100)\n",
    "study.optimize(optimize_function, n_trials=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'n_steps': 2592,\n",
       " 'gamma': 0.8450658799536929,\n",
       " 'learning_rate': 2.4952570083224206e-05,\n",
       " 'clip_range': 0.3315680151889212,\n",
       " 'gae_lambda': 0.9606306727084878}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyperparameters of the best trial in the study.\n",
    "study.best_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FrozenTrial(number=1, values=[680.0], datetime_start=datetime.datetime(2022, 3, 10, 10, 28, 37, 341786), datetime_complete=datetime.datetime(2022, 3, 10, 10, 29, 31, 487396), params={'n_steps': 2592, 'gamma': 0.8450658799536929, 'learning_rate': 2.4952570083224206e-05, 'clip_range': 0.3315680151889212, 'gae_lambda': 0.9606306727084878}, distributions={'n_steps': IntUniformDistribution(high=8192, low=2048, step=1), 'gamma': LogUniformDistribution(high=0.9999, low=0.8), 'learning_rate': LogUniformDistribution(high=0.0001, low=1e-05), 'clip_range': UniformDistribution(high=0.4, low=0.1), 'gae_lambda': UniformDistribution(high=0.99, low=0.8)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=1, state=TrialState.COMPLETE, value=None)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Full record of the best trial (value, parameters, distributions, timestamps).\n",
    "study.best_trial"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "9465aae7e0ab1403d672807d1a0963d86dbda2f584fbe3054c36cf78311c6c77"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
